import numpy as np
from math import *
import define_experiment as exp
import scipy.interpolate as sci
from scipy.integrate import trapz, solve_ivp

# Module-level default rover definition. The functions below all take a rover
# dict as an argument, so this instance is presumably the default for scripts
# that import this module. Units are not stated here (presumably SI; confirm).
#   wheel_assembly: one wheel / speed reducer / motor set; get_mass counts it 6 times.
#   motor['effcy_tau'] / motor['effcy']: efficiency table versus torque, interpolated in battenergy.
#   telemetry: filled in by simulate_rover.
rover = {'wheel_assembly': {'wheel': {'radius': 0.3, 'mass': 1},
                            'speed_reducer': {'type': 'reverted', 'diam_pinion': 0.05, 'diam_gear': 0.08, 'mass': 1.50},
                            'motor': {'torque_stall': 175, 'torque_noload': 0, 'speed_noload': 3.9, 'mass': 5,
                                      'effcy_tau': np.array([0, 10, 20, 40, 70, 165]),
                                      'effcy': np.array([0, 0.55, 0.75, 0.71, 0.5, 0.05])}},
         'chassis': {'mass': 659}, 'science_payload': {'mass': 75}, 'power_subsys': {'mass': 90}, 'telemetry': {}}
# Planet constants; 'g' multiplies the rover mass in F_gravity and F_rolling.
# NOTE(review): 3.72 looks like Mars surface gravity in m/s^2 - confirm.
planet = {'g': 3.72}
# Runs at import time: load the test case from the project-local
# define_experiment module. Element [0] is used as the experiment dict and
# element [1] as the end-event dict (the two dicts simulate_rover expects).
experiment1 = exp.experiment1()
experiment = experiment1[0]
end_event = experiment1[1]


def tau_dcmotor(omega, motor):
    """
    Return the motor shaft torque for the given motor shaft speed(s).

    Implements the linear torque-speed curve

        tau = tau_s - ((tau_s - tau_NL) / w_NL) * omega

    where tau_s = motor['torque_stall'], tau_NL = motor['torque_noload'] and
    w_NL = motor['speed_noload']. Speeds below zero return the stall torque
    and speeds above the no-load speed return zero torque.

    Parameters
    ----------
    omega : int, float or 1D numpy.ndarray
        Motor shaft speed(s).
    motor : dict
        Motor definition containing 'torque_stall', 'torque_noload' and
        'speed_noload'.

    Returns
    -------
    float or numpy.ndarray
        Torque for a scalar ``omega``; a float array of the same length for
        an array ``omega``.

    Raises
    ------
    Exception
        If ``motor`` is not a dict, ``omega`` is neither a scalar nor a numpy
        array, or ``omega`` is an array that is not 1D.
    """
    if not isinstance(motor, dict):
        raise Exception('Motor input must be a dictionary')

    tau_s = motor['torque_stall']
    w_NL = motor['speed_noload']
    tau_NL = motor['torque_noload']

    if not isinstance(omega, (int, float, np.ndarray)):
        raise Exception('Omega input must be a float or a 1D numpy array')

    if isinstance(omega, np.ndarray):
        if np.ndim(omega) != 1:
            raise Exception('Omega input must be a 1D numpy array')
        # Evaluate the whole array at once instead of appending element by
        # element (np.append in a loop copies the array each time: O(n^2)).
        # Casting to float matches the scalar arithmetic used per element.
        w = omega.astype(float)
        tau = tau_s - ((tau_s - tau_NL) / w_NL) * w
        tau = np.where(w < 0, tau_s, tau)       # reverse speed -> stall torque
        return np.where(w > w_NL, 0.0, tau)     # beyond no-load speed -> zero torque

    tau = tau_s - ((tau_s - tau_NL) / w_NL) * omega
    if omega < 0:
        tau = tau_s
    if omega > w_NL:
        tau = 0
    return tau


def get_gear_ratio(speed_reducer):
    """
    Return the speed-reduction ratio Ng of a reverted gear set.

    Ng = (diam_gear / diam_pinion) ** 2

    Parameters
    ----------
    speed_reducer : dict
        Must contain 'type' (case-insensitively equal to 'reverted'),
        'diam_pinion' and 'diam_gear'.

    Raises
    ------
    Exception
        If the input is not a dict or the reducer type is not 'reverted'.
    """
    if type(speed_reducer) is not dict:
        raise Exception('Speed_reducer input must be a dictionary')

    if speed_reducer['type'].lower() != 'reverted':
        raise Exception('The type is not reverted')

    pinion = speed_reducer['diam_pinion']
    gear = speed_reducer['diam_gear']
    return (gear / pinion) ** 2


def get_mass(rover):
    """
    Return the total rover mass.

    The total is chassis + science payload + power subsystem masses plus six
    wheel assemblies, each consisting of a wheel, a speed reducer and a motor.

    Raises
    ------
    Exception
        If ``rover`` is not a dict.
    """
    if type(rover) is not dict:
        raise Exception('Rover input must be a dictionary')

    assembly = rover['wheel_assembly']
    one_assembly = (assembly['wheel']['mass']
                    + assembly['speed_reducer']['mass']
                    + assembly['motor']['mass'])
    body = (rover['chassis']['mass']
            + rover['science_payload']['mass']
            + rover['power_subsys']['mass'])
    return body + 6 * one_assembly


def F_drive(omega, rover):
    """
    Compute the total drive force produced by the rover's six wheels.

    The torque at each wheel is tau_dcmotor(omega) * Ng (motor torque times
    the gear ratio). Dividing by the wheel radius gives the per-wheel force,
    which is then multiplied by 6.

    Parameters
    ----------
    omega : int, float or 1D numpy.ndarray
        Motor shaft speed(s).
    rover : dict
        Rover definition providing the wheel radius, motor and speed reducer.

    Returns
    -------
    float or numpy.ndarray
        Drive force for a scalar ``omega``; a float array of the same length
        for an array ``omega``.

    Raises
    ------
    Exception
        If ``rover`` is not a dict, or ``omega`` is not a scalar / 1D array.
    """
    if not isinstance(rover, dict):
        raise Exception('Rover input must be a dictionary')

    r = rover['wheel_assembly']['wheel']['radius']

    if not isinstance(omega, (int, float, np.ndarray)):
        raise Exception('Omega input must be a float or a 1D numpy array')
    if isinstance(omega, np.ndarray) and np.ndim(omega) != 1:
        raise Exception('Omega input must be a 1D numpy array')

    # tau_dcmotor accepts scalars and 1D arrays alike, so a single call covers
    # both cases (no O(n^2) np.append loop) and the gear ratio is looked up once.
    tau = tau_dcmotor(omega, rover['wheel_assembly']['motor']) * get_gear_ratio(
        rover['wheel_assembly']['speed_reducer'])
    return (tau / r) * 6


def F_gravity(terrain_angle, rover, planet):
    """
    Compute the component of the rover's weight acting along the terrain.

        Fgt = -m * g * sin(terrain_angle)

    where m = get_mass(rover) and g = planet['g']. Positive terrain angles
    produce a negative force.

    Parameters
    ----------
    terrain_angle : int, float or 1D numpy.ndarray
        Terrain angle(s) in degrees; every value must lie in [-75, 75].
    rover : dict
        Rover definition (passed to get_mass).
    planet : dict
        Must contain 'g'.

    Returns
    -------
    float or numpy.ndarray
        Force for a scalar angle; a float array of the same length for an
        array of angles.

    Raises
    ------
    Exception
        On a non-dict rover/planet, a terrain angle that is not a scalar or
        1D numpy array, or any angle outside [-75, 75].
    """
    if not isinstance(rover, dict):
        raise Exception('Rover input must be a dictionary')
    if not isinstance(planet, dict):
        raise Exception('Planet input must be a dictionary')
    if not isinstance(terrain_angle, (int, float, np.ndarray)):
        raise Exception('Terrain angle input must be a float or a 1D numpy array')

    if isinstance(terrain_angle, np.ndarray):
        if np.ndim(terrain_angle) != 1:
            raise Exception('Terrain angle input must be a 1D numpy array')
        theta = terrain_angle.astype(float)
        if np.any((theta < -75) | (theta > 75)):
            raise Exception('Terrain angle is out of range')
        # One vectorized evaluation instead of an O(n^2) np.append loop;
        # the rover mass is computed once rather than per element.
        return -(get_mass(rover) * planet['g'] * np.sin(theta * (pi / 180)))

    if (terrain_angle < -75) or (terrain_angle > 75):
        raise Exception('Terrain angle is out of range')
    return -(get_mass(rover) * planet['g'] * sin(terrain_angle * (pi / 180)))


def F_rolling(omega, terrain_angle, rover, planet, Crr):
    """
    Compute the total rolling-resistance force acting on the rover.

    For each (omega, terrain_angle) pair:

        v   = r * omega / Ng                    (rover speed)
        Fn  = m * g * |cos(terrain_angle)|      (normal force)
        Frr = -erf(40 * v) * Crr * Fn

    Because Crr >= 0 is enforced, the erf factor makes the force oppose the
    sign of v and vanish as v approaches 0.

    Parameters
    ----------
    omega : int, float or 1D numpy.ndarray
        Motor shaft speed(s); divided by the gear ratio to get wheel speed.
    terrain_angle : int, float or 1D numpy.ndarray
        Terrain angle(s) in degrees, each within [-75, 75]. Must be the same
        kind (scalar or array) and the same length as ``omega``.
    rover : dict
        Rover definition (wheel radius, speed reducer, masses).
    planet : dict
        Must contain 'g'.
    Crr : int or float
        Non-negative rolling-resistance coefficient.

    Returns
    -------
    float or numpy.ndarray
        Force for scalar inputs; a float array of matching length otherwise.

    Raises
    ------
    Exception
        On invalid input types, shapes or sizes, a negative/non-scalar Crr,
        or any terrain angle outside [-75, 75].
    """
    if not isinstance(rover, dict):
        raise Exception('Rover input must be a dictionary')
    if not isinstance(planet, dict):
        raise Exception('Planet input must be a dictionary')
    if not isinstance(terrain_angle, (int, float, np.ndarray)):
        raise Exception('Terrain angle input must be a float or a 1D numpy array')
    if not isinstance(omega, (int, float, np.ndarray)):
        raise Exception('Omega input must be a float or a 1D numpy array')

    is_vector = isinstance(terrain_angle, np.ndarray)
    if isinstance(omega, np.ndarray) != is_vector:
        raise Exception('Terrain angle and Omega must be the same type')
    if is_vector:
        # Check dimensionality before len(): len() of a 0-d array raises TypeError.
        if np.ndim(terrain_angle) != 1:
            raise Exception('Terrain angle input must be a 1D numpy array')
        if np.ndim(omega) != 1:
            raise Exception('Omega input must be a 1D numpy array')
        if len(omega) != len(terrain_angle):
            raise Exception('Terrain angle vector and omega vector must be the same size')

    if not isinstance(Crr, (int, float)) or Crr < 0:
        raise Exception('Crr must be a positive scalar')

    # Loop-invariant quantities, looked up once instead of per element.
    r = rover['wheel_assembly']['wheel']['radius']
    Ng = get_gear_ratio(rover['wheel_assembly']['speed_reducer'])
    weight = get_mass(rover) * planet['g']

    def rolling_force(w, angle_deg):
        # Rolling resistance for one scalar (motor speed, terrain angle) pair.
        if (angle_deg < -75) or (angle_deg > 75):
            raise Exception('Terrain angle is out of range')
        v_rover = r * (w / Ng)
        Fn = weight * abs(cos(angle_deg * (pi / 180)))
        return -(erf(40 * v_rover) * (Crr * Fn))

    if not is_vector:
        return rolling_force(omega, terrain_angle)

    # tolist() yields plain Python scalars (same as .item() per element);
    # building the array once avoids the O(n^2) np.append pattern.
    return np.array([rolling_force(w, a) for w, a in zip(omega.tolist(), terrain_angle.tolist())],
                    dtype=float)


def F_net(omega, terrain_angle, rover, planet, Crr):
    """
    Compute the net force on the rover.

        Fnet = F_drive(omega) + F_gravity(terrain_angle) + F_rolling(omega, terrain_angle)

    Parameters
    ----------
    omega : int, float or 1D numpy.ndarray
        Motor shaft speed(s).
    terrain_angle : int, float or 1D numpy.ndarray
        Terrain angle(s) in degrees, each within [-75, 75]. Must be the same
        kind (scalar or array) and the same length as ``omega``.
    rover, planet : dict
        Model definitions forwarded to the component force functions.
    Crr : int or float
        Non-negative rolling-resistance coefficient.

    Returns
    -------
    float or numpy.ndarray
        Net force for scalar inputs; a float array of matching length otherwise.

    Raises
    ------
    Exception
        On invalid input types, shapes or sizes, a negative/non-scalar Crr,
        or any terrain angle outside [-75, 75].
    """
    if not isinstance(rover, dict):
        raise Exception('Rover input must be a dictionary')
    if not isinstance(planet, dict):
        raise Exception('Planet input must be a dictionary')
    if not isinstance(terrain_angle, (int, float, np.ndarray)):
        raise Exception('Terrain angle input must be a float or a 1D numpy array')
    if not isinstance(omega, (int, float, np.ndarray)):
        raise Exception('Omega input must be a float or a 1D numpy array')

    is_vector = isinstance(terrain_angle, np.ndarray)
    if isinstance(omega, np.ndarray) != is_vector:
        raise Exception('Terrain angle and Omega must be the same type')
    if is_vector:
        # Check dimensionality before len(): len() of a 0-d array raises TypeError.
        if np.ndim(terrain_angle) != 1:
            raise Exception('Terrain angle input must be a 1D numpy array')
        if np.ndim(omega) != 1:
            raise Exception('Omega input must be a 1D numpy array')
        if len(omega) != len(terrain_angle):
            raise Exception('Terrain angle vector and omega vector must be the same size')

    if not isinstance(Crr, (int, float)) or Crr < 0:
        raise Exception('Crr must be a positive scalar')

    angles = np.asarray(terrain_angle, dtype=float)
    if np.any((angles < -75) | (angles > 75)):
        raise Exception('Terrain angle is out of range')

    # Each component function accepts scalars and 1D arrays, so the forces
    # are computed with one call each and summed element-wise (no O(n^2)
    # per-element loop).
    return (F_drive(omega, rover)
            + F_gravity(terrain_angle, rover, planet)
            + F_rolling(omega, terrain_angle, rover, planet, Crr))


def motorW(v, rover):
    """
    Convert rover speed(s) to motor shaft speed(s).

        w = (v / r) * Ng

    where r is the wheel radius and Ng the speed-reducer gear ratio.

    Parameters
    ----------
    v : int, float or 1D numpy.ndarray
        Rover speed(s).
    rover : dict
        Rover definition providing the wheel radius and speed reducer.

    Returns
    -------
    float or numpy.ndarray
        Motor speed for a scalar ``v``; a float array of the same length for
        an array ``v``.

    Raises
    ------
    Exception
        If ``rover`` is not a dict, or ``v`` is not a scalar / 1D array.
    """
    if not isinstance(rover, dict):
        raise Exception('Rover input must be a dictionary')
    if not isinstance(v, (int, float, np.ndarray)):
        raise Exception('Velocity input must be a float or a 1D numpy array')
    if isinstance(v, np.ndarray):
        if np.ndim(v) != 1:
            raise Exception('Velocity Input must be a 1D array')
        v = v.astype(float)

    # Radius and gear ratio are loop invariants: look them up once and apply
    # the conversion to the whole input (no O(n^2) np.append loop).
    r = rover['wheel_assembly']['wheel']['radius']
    Ng = get_gear_ratio(rover['wheel_assembly']['speed_reducer'])
    return (v / r) * Ng


def rover_dynamics(t, y, rover, planet, experiment):
    """
    Evaluate the time derivative of the rover state for solve_ivp.

    The state is y = [velocity, position]. The terrain angle at the current
    position comes from a cubic interpolation (with extrapolation) of
    experiment['alpha_deg'] over experiment['alpha_dist']. The net force at
    that angle and the current motor speed (from F_net), divided by the rover
    mass, gives the acceleration.

    Parameters
    ----------
    t : int or float
        Current time (only validated; the dynamics do not depend on it).
    y : numpy.ndarray
        Two-element 1D state vector [velocity, position].
    rover, planet, experiment : dict
        Model definitions; experiment must provide 'alpha_dist', 'alpha_deg'
        and 'Crr'.

    Returns
    -------
    list
        [acceleration, velocity].

    Raises
    ------
    Exception
        If any input has the wrong type or shape.
    """
    if type(rover) is not dict:
        raise Exception('Rover input must be a dictionary')
    if type(planet) is not dict:
        raise Exception('Planet input must be a dictionary')
    if type(experiment) is not dict:
        raise Exception('Experiment input must be a dictionary')
    if not isinstance(t, (int, float)):
        raise Exception('Time input must be a scalar')
    if type(y) is not np.ndarray:
        raise Exception('Dependent variables must be a vector')
    if np.ndim(y) != 1:
        raise Exception('Input for dependent variables must be 1 Dimensional')
    if len(y) != 2:
        raise Exception('Input for dependent variables must have only two elements')

    speed = y[0].item()
    total_mass = get_mass(rover)
    terrain_profile = sci.interp1d(experiment['alpha_dist'], experiment['alpha_deg'],
                                   kind='cubic', fill_value='extrapolate')
    motor_speed = motorW(speed, rover)
    slope_deg = terrain_profile(y[1]).item()
    net_force = F_net(motor_speed, slope_deg, rover, planet, experiment['Crr'])

    return [(1 / total_mass) * net_force, speed]

def mechpower(v, rover):
    """
    Compute the mechanical power at the motor shaft for rover speed(s) v.

        w = motorW(v, rover)
        tau = tau_dcmotor(w, motor)
        P = w * tau

    Parameters
    ----------
    v : int, float or 1D numpy.ndarray
        Rover speed(s). Arrays of any numeric dtype are accepted.
    rover : dict
        Rover definition.

    Returns
    -------
    float or numpy.ndarray
        Power for a scalar ``v``; a float array of the same length for an
        array ``v``.

    Raises
    ------
    Exception
        If ``rover`` is not a dict, or ``v`` is not a scalar / 1D array.
    """
    if not isinstance(rover, dict):
        raise Exception('Rover input must be a dictionary')
    if not isinstance(v, (int, float, np.ndarray)):
        raise Exception('Velocity input must be a float or a 1D numpy array')
    if isinstance(v, np.ndarray) and np.ndim(v) != 1:
        raise Exception('Velocity Input must be a 1D array')

    # motorW and tau_dcmotor both handle scalars and 1D arrays, so the whole
    # input is processed at once. Indexing element by element would pass numpy
    # scalars (e.g. np.int64, np.float32) to motorW, which are not int/float
    # instances and would be rejected by its type check.
    w = motorW(v, rover)
    tau = tau_dcmotor(w, rover['wheel_assembly']['motor'])
    return w * tau

def battenergy(t, v, rover):
    """
    Compute the total battery energy consumed over a velocity time history.

    At each sample, the electrical power is 6 * P / eff(tau), where:
    - w = motorW(v) and tau = tau_dcmotor(w);
    - P = w * tau is the motor mechanical power;
    - eff is a cubic interpolation (extrapolated outside the table) of
      motor['effcy'] versus motor['effcy_tau'].

    The energy is the trapezoidal integral of that power over ``t``.

    Parameters
    ----------
    t : 1D numpy.ndarray
        Sample times; every element must be an int or float.
    v : 1D numpy.ndarray
        Rover speeds at those times; same length as ``t``, every element
        an int or float.
    rover : dict
        Rover definition.

    Returns
    -------
    float
        Integrated battery energy.

    Raises
    ------
    Exception
        On non-dict rover, non-array or non-1D inputs, mismatched lengths,
        or non-numeric elements.
    """
    if not isinstance(rover, dict):
        raise Exception('Rover input must be a dictionary')
    if not isinstance(v, np.ndarray):
        raise Exception('Velocity input must be a 1D numpy array')
    if not isinstance(t, np.ndarray):
        raise Exception('Time input must be a a 1D numpy array')
    if np.ndim(v) != 1:
        raise Exception('Vector function must be a 1D array')
    if np.ndim(t) != 1:
        raise Exception('Time function must be a 1D array')
    if len(v) != len(t):
        raise Exception('Time and Velocity vectors must be the same size')

    motor = rover['wheel_assembly']['motor']
    effcy_fun = sci.interp1d(motor['effcy_tau'], motor['effcy'], kind='cubic', fill_value='extrapolate')

    # Validate every sample, in (v, t) order, before computing anything.
    for v_i, t_i in zip(v.tolist(), t.tolist()):
        if not isinstance(v_i, (int, float)):
            raise Exception('All values in v must be numerical')
        if not isinstance(t_i, (int, float)):
            raise Exception('All values in t must be numerical')

    # Compute motor speed and torque once, for all samples. P = w * tau is the
    # same quantity mechpower returns; reusing tau avoids evaluating the motor
    # model twice per sample.
    w = motorW(v.astype(float), rover)
    tau = tau_dcmotor(w, motor)
    # NOTE(review): with the default efficiency table, eff(0) is 0, so samples
    # with zero torque yield 0/0 = nan. Torques outside the table are
    # extrapolated by the cubic fit. Confirm this is acceptable.
    power = 6 * ((w * tau) / effcy_fun(tau))

    return trapz(power, t)


def end_of_mission_event(end_event):
    """
    Build the solve_ivp event functions that end the mission simulation.

    The mission ends when any of the following happens:
    - the rover reaches end_event['max_distance'];
    - the simulation time reaches end_event['max_time'];
    - the velocity drops through end_event['min_velocity'].

    Each event is terminal. Only the velocity event is directional: it
    triggers on a decreasing crossing (direction = -1).

    The state is assumed to be y = [velocity, distance traveled].
    """
    max_distance = end_event['max_distance']
    max_time = end_event['max_time']
    min_velocity = end_event['min_velocity']

    def distance_left(t, y):
        return max_distance - y[1]

    def time_left(t, y):
        return max_time - t

    def velocity_threshold(t, y):
        return y[0] - min_velocity

    # Any single event can stop the integrator on its own.
    for event in (distance_left, time_left, velocity_threshold):
        event.terminal = True
    # Only a falling velocity crossing counts.
    velocity_threshold.direction = -1

    return [distance_left, time_left, velocity_threshold]


def simulate_rover(rover, planet, experiment, end_event):
    """
    Integrate the rover equations of motion and record summary telemetry.

    The integration works as follows:
    - solve_ivp (RK45) integrates rover_dynamics from
      experiment['initial_conditions'] over experiment['time_range'];
    - it stops early on any event from end_of_mission_event(end_event);
    - the results are written into rover['telemetry'], which is created if
      absent (existing keys are kept).

    Keys written: 'Time', 'completion_time', 'velocity', 'position',
    'distance_traveled', 'max_velocity', 'average_velocity', 'power',
    'battery_energy' and 'energy_per_distance'.

    Parameters
    ----------
    rover, planet, experiment, end_event : dict
        Model, planet, experiment and termination definitions.

    Returns
    -------
    dict
        The same ``rover`` dict, modified in place.

    Raises
    ------
    Exception
        If any argument is not a dict.
    """
    if not isinstance(rover, dict):
        raise Exception('Rover input must be a dictionary')
    if not isinstance(planet, dict):
        raise Exception('Planet input must be a dictionary')
    if not isinstance(experiment, dict):
        raise Exception('Experiment input must be a dictionary')
    if not isinstance(end_event, dict):
        raise Exception('end_event input must be a dictionary')

    def dydt(t, y):
        return rover_dynamics(t, y, rover, planet, experiment)

    y0 = experiment['initial_conditions']
    tspan = experiment['time_range']

    solution = solve_ivp(dydt, tspan, y0, method='RK45', events=end_of_mission_event(end_event))
    t = solution.t        # time points
    v = solution.y[0]     # velocity at each time point
    x = solution.y[1]     # position at each time point
    dist = x[-1]

    # Compute everything first so the rover dict is left untouched if any
    # post-processing step raises.
    max_vel = np.max(v)
    ave_vel = np.mean(v)
    power = mechpower(v, rover)
    energy = battenergy(t, v, rover)

    # setdefault: a rover dict built without a 'telemetry' entry would
    # otherwise raise KeyError after the whole simulation has run.
    rover.setdefault('telemetry', {}).update({
        'Time': t,
        'completion_time': t[-1],
        'velocity': v,
        'position': x,
        'distance_traveled': dist,
        'max_velocity': max_vel,
        'average_velocity': ave_vel,
        'power': power,
        'battery_energy': energy,
        'energy_per_distance': energy / dist,
    })

    return rover